{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 用增强学习玩车杆游戏"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "玩一局"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.\u001b[0m\n",
      "新一局游戏 初始观测 = [ 0.00974547  0.04064659  0.0486582  -0.04658112]\n",
      "0: 动作 = 0\n",
      "0: 观测 = [ 0.0105584  -0.15513809  0.04772658  0.26104819], 本步得分 = 1.0, 结束指示 = False, 其他信息 = {}\n",
      "1: 动作 = 1\n",
      "1: 观测 = [ 0.00745564  0.03927122  0.05294754 -0.01620744], 本步得分 = 1.0, 结束指示 = False, 其他信息 = {}\n",
      "2: 动作 = 1\n",
      "2: 观测 = [ 0.00824106  0.23359543  0.05262339 -0.29172597], 本步得分 = 1.0, 结束指示 = False, 其他信息 = {}\n",
      "3: 动作 = 0\n",
      "3: 观测 = [0.01291297 0.03776419 0.04678887 0.0170784 ], 本步得分 = 1.0, 结束指示 = False, 其他信息 = {}\n",
      "4: 动作 = 1\n",
      "4: 观测 = [ 0.01366825  0.23218499  0.04713044 -0.26048273], 本步得分 = 1.0, 结束指示 = False, 其他信息 = {}\n",
      "5: 动作 = 1\n",
      "5: 观测 = [ 0.01831195  0.42660357  0.04192079 -0.53793554], 本步得分 = 1.0, 结束指示 = False, 其他信息 = {}\n",
      "6: 动作 = 1\n",
      "6: 观测 = [ 0.02684402  0.62111185  0.03116208 -0.81712053], 本步得分 = 1.0, 结束指示 = False, 其他信息 = {}\n",
      "7: 动作 = 1\n",
      "7: 观测 = [ 0.03926626  0.81579365  0.01481966 -1.09984129], 本步得分 = 1.0, 结束指示 = False, 其他信息 = {}\n",
      "8: 动作 = 1\n",
      "8: 观测 = [ 0.05558213  1.01071745 -0.00717716 -1.38783806], 本步得分 = 1.0, 结束指示 = False, 其他信息 = {}\n",
      "9: 动作 = 1\n",
      "9: 观测 = [ 0.07579648  1.20592811 -0.03493392 -1.68275657], 本步得分 = 1.0, 结束指示 = False, 其他信息 = {}\n",
      "10: 动作 = 1\n",
      "10: 观测 = [ 0.09991505  1.40143672 -0.06858905 -1.98610904], 本步得分 = 1.0, 结束指示 = False, 其他信息 = {}\n",
      "11: 动作 = 0\n",
      "11: 观测 = [ 0.12794378  1.20709839 -0.10831123 -1.71543635], 本步得分 = 1.0, 结束指示 = False, 其他信息 = {}\n",
      "12: 动作 = 0\n",
      "12: 观测 = [ 0.15208575  1.01337295 -0.14261996 -1.4583323 ], 本步得分 = 1.0, 结束指示 = False, 其他信息 = {}\n",
      "13: 动作 = 1\n",
      "13: 观测 = [ 0.17235321  1.20992719 -0.17178661 -1.79195849], 本步得分 = 1.0, 结束指示 = False, 其他信息 = {}\n",
      "14: 动作 = 0\n",
      "14: 观测 = [ 0.19655175  1.01709749 -0.20762578 -1.55722859], 本步得分 = 1.0, 结束指示 = False, 其他信息 = {}\n",
      "15: 动作 = 0\n",
      "15: 观测 = [ 0.2168937   0.82497802 -0.23877035 -1.33584292], 本步得分 = 1.0, 结束指示 = True, 其他信息 = {}\n"
     ]
    }
   ],
   "source": [
    "import gym\n",
    "env = gym.make('CartPole-v0') # 获得游戏环境\n",
    "observation = env.reset() # 复位游戏环境,新一局游戏开始\n",
    "print ('新一局游戏 初始观测 = {}'.format(observation))\n",
    "for t in range(200):\n",
    "    env.render()\n",
    "    action = env.action_space.sample() # 随机选择动作\n",
    "    print ('{}: 动作 = {}'.format(t, action))\n",
    "    observation, reward, done, info = env.step(action) # 执行行为\n",
    "    print ('{}: 观测 = {}, 本步得分 = {}, 结束指示 = {}, 其他信息 = {}'.format(\n",
    "            t, observation, reward, done, info))\n",
    "    if done:\n",
    "        break\n",
    "env.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "玩多局"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.\u001b[0m\n",
      "第0局得分 = 14.0\n",
      "第1局得分 = 19.0\n",
      "第2局得分 = 14.0\n",
      "第3局得分 = 12.0\n",
      "第4局得分 = 16.0\n",
      "第5局得分 = 25.0\n",
      "第6局得分 = 13.0\n",
      "第7局得分 = 41.0\n",
      "第8局得分 = 22.0\n",
      "第9局得分 = 13.0\n",
      "第10局得分 = 19.0\n",
      "第11局得分 = 20.0\n",
      "第12局得分 = 16.0\n",
      "第13局得分 = 21.0\n",
      "第14局得分 = 26.0\n",
      "第15局得分 = 12.0\n",
      "第16局得分 = 12.0\n",
      "第17局得分 = 12.0\n",
      "第18局得分 = 19.0\n",
      "第19局得分 = 58.0\n"
     ]
    }
   ],
   "source": [
    "import gym\n",
    "env = gym.make('CartPole-v0')\n",
    "n_episode = 20\n",
    "for i_episode in range(n_episode):\n",
    "    observation = env.reset()\n",
    "    episode_reward = 0\n",
    "    while True:\n",
    "        # env.render()\n",
    "        action = env.action_space.sample() # 随机选\n",
    "        observation, reward, done, _ = env.step(action)\n",
    "        episode_reward += reward\n",
    "        state = observation\n",
    "        if done:\n",
    "            break\n",
    "    print ('第{}局得分 = {}'.format(i_episode, episode_reward))\n",
    "env.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Cart Pole Environment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.\u001b[0m\n"
     ]
    }
   ],
   "source": [
    "import gym\n",
    "env = gym.make('CartPole-v0')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "搭建 DQN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn as nn\n",
    "model = nn.Sequential(\n",
    "        nn.Linear(env.observation_space.shape[0], 128),\n",
    "        nn.ReLU(),\n",
    "        nn.Linear(128, 128),\n",
    "        nn.ReLU(),\n",
    "        nn.Linear(128, env.action_space.n)\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "def act(model, state, epsilon):\n",
    "    if random.random() > epsilon: # 选最大的\n",
    "        state = torch.FloatTensor(state).unsqueeze(0)\n",
    "        q_value = model.forward(state)\n",
    "        action = q_value.max(1)[1].item()\n",
    "    else: # 随便选\n",
    "        action = random.randrange(env.action_space.n)\n",
    "    return action"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "训练"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# epsilon值不断下降\n",
    "import math\n",
    "def calc_epsilon(t, epsilon_start=1.0,\n",
    "        epsilon_final=0.01, epsilon_decay=500):\n",
    "    epsilon = epsilon_final + (epsilon_start - epsilon_final) \\\n",
    "            * math.exp(-1. * t / epsilon_decay)\n",
    "    return epsilon"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 最近历史缓存\n",
    "import numpy as np\n",
    "from collections import deque\n",
    "\n",
    "batch_size = 32\n",
    "\n",
    "class ReplayBuffer:\n",
    "    def __init__(self, capacity):\n",
    "        self.buffer = deque(maxlen=capacity)\n",
    "    \n",
    "    def push(self, state, action, reward, next_state, done):\n",
    "        state = np.expand_dims(state, 0)\n",
    "        next_state = np.expand_dims(next_state, 0)\n",
    "        self.buffer.append((state, action, reward, next_state, done))\n",
    "    \n",
    "    def sample(self, batch_size):\n",
    "        state, action, reward, next_state, done = zip( \\\n",
    "                *random.sample(self.buffer, batch_size))\n",
    "        concat_state = np.concatenate(state)\n",
    "        concat_next_state = np.concatenate(next_state)\n",
    "        return concat_state, action, reward, concat_next_state, done\n",
    "    \n",
    "    def __len__(self):\n",
    "        return len(self.buffer)\n",
    "\n",
    "replay_buffer = ReplayBuffer(1000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "第0局收益 = 29.0\n",
      "第1局收益 = 21.0\n",
      "第2局收益 = 21.0\n",
      "第3局收益 = 15.0\n",
      "第4局收益 = 32.0\n",
      "第5局收益 = 15.0\n",
      "第6局收益 = 14.0\n",
      "第7局收益 = 42.0\n",
      "第8局收益 = 12.0\n",
      "第9局收益 = 15.0\n",
      "第10局收益 = 10.0\n",
      "第11局收益 = 44.0\n",
      "第12局收益 = 20.0\n",
      "第13局收益 = 24.0\n",
      "第14局收益 = 16.0\n",
      "第15局收益 = 25.0\n",
      "第16局收益 = 10.0\n",
      "第17局收益 = 15.0\n",
      "第18局收益 = 17.0\n",
      "第19局收益 = 53.0\n",
      "第20局收益 = 33.0\n",
      "第21局收益 = 20.0\n",
      "第22局收益 = 31.0\n",
      "第23局收益 = 22.0\n",
      "第24局收益 = 13.0\n",
      "第25局收益 = 27.0\n",
      "第26局收益 = 48.0\n",
      "第27局收益 = 23.0\n",
      "第28局收益 = 48.0\n",
      "第29局收益 = 23.0\n",
      "第30局收益 = 32.0\n",
      "第31局收益 = 24.0\n",
      "第32局收益 = 65.0\n",
      "第33局收益 = 29.0\n",
      "第34局收益 = 25.0\n",
      "第35局收益 = 63.0\n",
      "第36局收益 = 32.0\n",
      "第37局收益 = 65.0\n",
      "第38局收益 = 35.0\n",
      "第39局收益 = 63.0\n",
      "第40局收益 = 54.0\n",
      "第41局收益 = 101.0\n",
      "第42局收益 = 75.0\n",
      "第43局收益 = 139.0\n",
      "第44局收益 = 70.0\n",
      "第45局收益 = 79.0\n",
      "第46局收益 = 129.0\n",
      "第47局收益 = 122.0\n",
      "第48局收益 = 95.0\n",
      "第49局收益 = 140.0\n",
      "第50局收益 = 140.0\n",
      "第51局收益 = 127.0\n",
      "第52局收益 = 171.0\n",
      "第53局收益 = 200.0\n",
      "第54局收益 = 128.0\n",
      "第55局收益 = 200.0\n",
      "第56局收益 = 200.0\n",
      "第57局收益 = 200.0\n",
      "第58局收益 = 200.0\n",
      "第59局收益 = 184.0\n",
      "第60局收益 = 200.0\n",
      "第61局收益 = 200.0\n",
      "第62局收益 = 200.0\n",
      "第63局收益 = 200.0\n",
      "第64局收益 = 199.0\n",
      "第65局收益 = 200.0\n",
      "第66局收益 = 200.0\n",
      "第67局收益 = 200.0\n",
      "第68局收益 = 200.0\n",
      "第69局收益 = 200.0\n",
      "第70局收益 = 200.0\n",
      "第71局收益 = 200.0\n",
      "第72局收益 = 200.0\n"
     ]
    }
   ],
   "source": [
    "import torch.optim\n",
    "optimizer = torch.optim.Adam(model.parameters())\n",
    "\n",
    "gamma = 0.99\n",
    "\n",
    "episode_rewards = [] # 各局得分,用来判断训练是否完成\n",
    "t = 0 # 训练步数,用于计算epsilon\n",
    "\n",
    "while True:\n",
    "    \n",
    "    # 开始新的一局\n",
    "    state = env.reset()\n",
    "    episode_reward = 0\n",
    "\n",
    "    while True:\n",
    "        epsilon = calc_epsilon(t)\n",
    "        action = act(model, state, epsilon)\n",
    "        next_state, reward, done, _ = env.step(action)\n",
    "        replay_buffer.push(state, action, reward, next_state, done)\n",
    "\n",
    "        state = next_state\n",
    "        episode_reward += reward\n",
    "\n",
    "        if len(replay_buffer) > batch_size:\n",
    "\n",
    "            # 计算时间差分误差\n",
    "            sample_state, sample_action, sample_reward, sample_next_state, \\\n",
    "                    sample_done = replay_buffer.sample(batch_size)\n",
    "\n",
    "            sample_state = torch.tensor(sample_state, dtype=torch.float32)\n",
    "            sample_action = torch.tensor(sample_action, dtype=torch.int64)\n",
    "            sample_reward = torch.tensor(sample_reward, dtype=torch.float32)\n",
    "            sample_next_state = torch.tensor(sample_next_state,\n",
    "                    dtype=torch.float32)\n",
    "            sample_done = torch.tensor(sample_done, dtype=torch.float32)\n",
    "            \n",
    "            next_qs = model(sample_next_state)\n",
    "            next_q, _ = next_qs.max(1)\n",
    "            expected_q = sample_reward + gamma * next_q * (1 - sample_done)\n",
    "            \n",
    "            qs = model(sample_state)\n",
    "            q = qs.gather(1, sample_action.unsqueeze(1)).squeeze(1)\n",
    "            \n",
    "            td_error = expected_q - q\n",
    "            \n",
    "            # 计算 MSE 损失\n",
    "            loss = td_error.pow(2).mean() \n",
    "            \n",
    "            # 根据损失改进网络\n",
    "            optimizer.zero_grad()\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            \n",
    "            t += 1\n",
    "            \n",
    "        if done: # 本局结束\n",
    "            i_episode = len(episode_rewards)\n",
    "            print ('第{}局收益 = {}'.format(i_episode, episode_reward))\n",
    "            episode_rewards.append(episode_reward)\n",
    "            break\n",
    "            \n",
    "    if len(episode_rewards) > 20 and np.mean(episode_rewards[-20:]) > 195:\n",
    "        break # 训练结束"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "使用\n",
    "（固定 $\\epsilon$ 的值为0）"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "第0局得分 = 200.0\n",
      "第1局得分 = 200.0\n",
      "第2局得分 = 200.0\n",
      "第3局得分 = 200.0\n",
      "第4局得分 = 200.0\n",
      "第5局得分 = 200.0\n",
      "第6局得分 = 200.0\n",
      "第7局得分 = 200.0\n",
      "第8局得分 = 200.0\n",
      "第9局得分 = 200.0\n",
      "第10局得分 = 200.0\n",
      "第11局得分 = 200.0\n",
      "第12局得分 = 200.0\n",
      "第13局得分 = 200.0\n",
      "第14局得分 = 200.0\n",
      "第15局得分 = 200.0\n",
      "第16局得分 = 200.0\n",
      "第17局得分 = 200.0\n",
      "第18局得分 = 200.0\n",
      "第19局得分 = 200.0\n"
     ]
    }
   ],
   "source": [
    "n_episode = 20\n",
    "for i_episode in range(n_episode):\n",
    "    observation = env.reset()\n",
    "    episode_reward = 0\n",
    "    while True:\n",
    "        # env.render()\n",
    "        action = act(model, observation, 0)\n",
    "        observation, reward, done, _ = env.step(action)\n",
    "        episode_reward += reward\n",
    "        state = observation\n",
    "        if done:\n",
    "            break\n",
    "    print ('第{}局得分 = {}'.format(i_episode, episode_reward))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
